{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Finetuning a PyTorch Image Classifier with Ray Train\n",
    "\n",
    "<a id=\"try-anyscale-quickstart-pytorch_resnet_finetune\" href=\"https://console.anyscale.com/register/ha?render_flow=ray&utm_source=ray_docs&utm_medium=docs&utm_campaign=pytorch_resnet_finetune\">\n",
    "    <img src=\"../../../_static/img/run-on-anyscale.svg\" alt=\"try-anyscale-quickstart\">\n",
    "</a>\n",
    "<br></br>\n",
    "\n",
    "This example fine-tunes a pre-trained ResNet model with Ray Train. \n",
    "\n",
    "For this example, the network architecture consists of the intermediate layer output of a pre-trained ResNet model, which feeds into a randomly initialized linear layer that outputs classification logits for our new task.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and preprocess finetuning dataset\n",
    "This example is adapted from PyTorch's [Finetuning Torchvision Models](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html) tutorial.\n",
    "It uses *hymenoptera_data* as the fine-tuning dataset, which contains two classes (bees and ants) and 397 images in total across the training and validation splits. The dataset is deliberately small and serves only for demonstration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "# To run the full example, set SMOKE_TEST to False\n",
    "SMOKE_TEST = True\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dataset is publicly available [here](https://www.kaggle.com/datasets/ajayrana/hymenoptera-data). Note that it is structured with directory names as the labels. Use `torchvision.datasets.ImageFolder()` to load the images and their corresponding labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, models, transforms\n",
    "import numpy as np\n",
    "\n",
    "# Data augmentation and normalization for training\n",
    "# Just normalization for validation\n",
    "data_transforms = {\n",
    "    \"train\": transforms.Compose(\n",
    "        [\n",
    "            transforms.RandomResizedCrop(224),\n",
    "            transforms.RandomHorizontalFlip(),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
    "        ]\n",
    "    ),\n",
    "    \"val\": transforms.Compose(\n",
    "        [\n",
    "            transforms.Resize(224),\n",
    "            transforms.CenterCrop(224),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
    "        ]\n",
    "    ),\n",
    "}\n",
    "\n",
    "def download_datasets():\n",
    "    os.system(\n",
    "        \"wget https://download.pytorch.org/tutorial/hymenoptera_data.zip >/dev/null 2>&1\"\n",
    "    )\n",
    "    os.system(\"unzip hymenoptera_data.zip >/dev/null 2>&1\")\n",
    "\n",
    "# Build torch datasets from the downloaded image folders\n",
    "def build_datasets():\n",
    "    torch_datasets = {}\n",
    "    for split in [\"train\", \"val\"]:\n",
    "        torch_datasets[split] = datasets.ImageFolder(\n",
    "            os.path.join(\"./hymenoptera_data\", split), data_transforms[split]\n",
    "        )\n",
    "    return torch_datasets\n"
   ]
  },
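  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how directory names become labels, the following standard-library sketch mimics the alphabetical class-to-index mapping that `ImageFolder` builds from a dataset root. `folder_class_to_idx` is a hypothetical helper written for illustration, not part of torchvision, and the temporary directory is only a stand-in for `./hymenoptera_data/train`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def folder_class_to_idx(root):\n",
    "    # Hypothetical helper: map each class sub-directory to an integer label,\n",
    "    # assigning indices in alphabetical order of the directory names\n",
    "    class_names = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())\n",
    "    return {name: idx for idx, name in enumerate(class_names)}\n",
    "\n",
    "\n",
    "# Stand-in for ./hymenoptera_data/train, created in a temporary directory\n",
    "with tempfile.TemporaryDirectory() as root:\n",
    "    for class_name in [\"bees\", \"ants\"]:\n",
    "        os.makedirs(os.path.join(root, class_name))\n",
    "    class_to_idx = folder_class_to_idx(root)\n",
    "\n",
    "print(class_to_idx)  # {'ants': 0, 'bees': 1}\n"
   ]
  },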
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "if SMOKE_TEST:\n",
    "    from torch.utils.data import Subset\n",
    "\n",
    "    def build_datasets():\n",
    "        torch_datasets = {}\n",
    "        for split in [\"train\", \"val\"]:\n",
    "            torch_datasets[split] = datasets.ImageFolder(\n",
    "                os.path.join(\"./hymenoptera_data\", split), data_transforms[split]\n",
    "            )\n",
    "            \n",
    "        # Only take a subset for smoke test\n",
    "        for split in [\"train\", \"val\"]:\n",
    "            indices = list(range(100))\n",
    "            torch_datasets[split] = Subset(torch_datasets[split], indices)\n",
    "        return torch_datasets\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Model and Fine-tuning configs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's define the training configuration that will be passed into the training loop function later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loop_config = {\n",
    "    \"input_size\": 224,  # Input image size (224 x 224)\n",
    "    \"batch_size\": 32,  # Batch size for training\n",
    "    \"num_epochs\": 10,  # Number of epochs to train for\n",
    "    \"lr\": 0.001,  # Learning Rate\n",
    "    \"momentum\": 0.9,  # SGD optimizer momentum\n",
    "}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's define our model. You can either create a model from pre-trained weights or reload the model checkpoint from a previous run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from ray.train import Checkpoint\n",
    "\n",
    "# Option 1: Initialize model with pretrained weights\n",
    "def initialize_model():\n",
    "    # Load pretrained model params\n",
    "    model = models.resnet50(pretrained=True)\n",
    "\n",
    "    # Replace the original classifier with a new Linear layer\n",
    "    num_features = model.fc.in_features\n",
    "    model.fc = nn.Linear(num_features, 2)\n",
    "\n",
    "    # Ensure all params get updated during finetuning\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad = True\n",
    "    return model\n",
    "\n",
    "\n",
    "# Option 2: Initialize model with a Ray Train checkpoint\n",
    "# Replace this with your own URI\n",
    "CHECKPOINT_FROM_S3 = Checkpoint(\n",
    "    path=\"s3://air-example-data/finetune-resnet-checkpoint/TorchTrainer_4f69f_00000_0_2023-02-14_14-04-09/checkpoint_000001/\"\n",
    ")\n",
    "\n",
    "\n",
    "def initialize_model_from_checkpoint(checkpoint: Checkpoint):\n",
    "    with checkpoint.as_directory() as tmpdir:\n",
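    "        # Load onto CPU so the checkpoint also works on machines without a GPU\n",
    "        state_dict = torch.load(\n",
    "            os.path.join(tmpdir, \"checkpoint.pt\"), map_location=\"cpu\"\n",
    "        )\n",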
    "    resnet50 = initialize_model()\n",
    "    resnet50.load_state_dict(state_dict[\"model\"])\n",
    "    return resnet50\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the Training Loop\n",
    "\n",
    "The `train_loop_per_worker` function defines the fine-tuning procedure for each worker.\n",
    "\n",
    "**1. Prepare dataloaders for each worker**:\n",
    "- This tutorial assumes you are using PyTorch's native `torch.utils.data.Dataset` for data input. {meth}`train.torch.prepare_data_loader() <ray.train.torch.prepare_data_loader>` prepares your `DataLoader` for distributed execution. You can also use Ray Data for more efficient preprocessing. For more details on using Ray Data for images, see the {doc}`Working with Images </data/working-with-images>` Ray Data user guide.\n",
    "\n",
    "**2. Prepare your model**:\n",
    "- {meth}`train.torch.prepare_model() <ray.train.torch.prepare_model>` prepares the model for distributed training. Under the hood, it wraps your torch model in `DistributedDataParallel`, which keeps the model weights in sync across all workers by averaging gradients during the backward pass.\n",
    "\n",
    "**3. Report metrics and checkpoint**:\n",
    "- {meth}`train.report() <ray.train.report>` reports metrics and checkpoints to Ray Train.\n",
    "- Saving checkpoints through {meth}`train.report(metrics, checkpoint=...) <ray.train.report>` automatically [uploads checkpoints to cloud storage](tune-cloud-checkpointing) (if configured) and lets you enable Ray Train worker fault tolerance later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from tempfile import TemporaryDirectory\n",
    "\n",
    "import ray.train as train\n",
    "from ray.train import Checkpoint\n",
    "\n",
    "\n",
    "\n",
    "def evaluate(logits, labels):\n",
    "    _, preds = torch.max(logits, 1)\n",
    "    corrects = torch.sum(preds == labels).item()\n",
    "    return corrects\n",
    "\n",
    "\n",
    "def train_loop_per_worker(configs):\n",
    "    import warnings\n",
    "\n",
    "    warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "    # Calculate the batch size for a single worker\n",
    "    worker_batch_size = configs[\"batch_size\"] // train.get_context().get_world_size()\n",
    "\n",
    "    # Download dataset once on local rank 0 worker\n",
    "    if train.get_context().get_local_rank() == 0:\n",
    "        download_datasets()\n",
    "    torch.distributed.barrier()\n",
    "\n",
    "    # Build datasets on each worker\n",
    "    torch_datasets = build_datasets()\n",
    "\n",
    "    # Prepare dataloader for each worker\n",
    "    dataloaders = dict()\n",
    "    dataloaders[\"train\"] = DataLoader(\n",
    "        torch_datasets[\"train\"], batch_size=worker_batch_size, shuffle=True\n",
    "    )\n",
    "    dataloaders[\"val\"] = DataLoader(\n",
    "        torch_datasets[\"val\"], batch_size=worker_batch_size, shuffle=False\n",
    "    )\n",
    "\n",
    "    # Distribute\n",
    "    dataloaders[\"train\"] = train.torch.prepare_data_loader(dataloaders[\"train\"])\n",
    "    dataloaders[\"val\"] = train.torch.prepare_data_loader(dataloaders[\"val\"])\n",
    "\n",
    "    device = train.torch.get_device()\n",
    "\n",
    "    # Prepare DDP Model, optimizer, and loss function\n",
    "    model = initialize_model()\n",
    "    model = train.torch.prepare_model(model)\n",
    "\n",
    "    optimizer = optim.SGD(\n",
    "        model.parameters(), lr=configs[\"lr\"], momentum=configs[\"momentum\"]\n",
    "    )\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    # Start training loops\n",
    "    for epoch in range(configs[\"num_epochs\"]):\n",
    "        # Each epoch has a training and validation phase\n",
    "        for phase in [\"train\", \"val\"]:\n",
    "            if phase == \"train\":\n",
    "                model.train()  # Set model to training mode\n",
    "            else:\n",
    "                model.eval()  # Set model to evaluate mode\n",
    "\n",
    "            running_loss = 0.0\n",
    "            running_corrects = 0\n",
    "\n",
    "            if train.get_context().get_world_size() > 1:\n",
    "                dataloaders[phase].sampler.set_epoch(epoch)\n",
    "\n",
    "            for inputs, labels in dataloaders[phase]:\n",
    "                inputs = inputs.to(device)\n",
    "                labels = labels.to(device)\n",
    "\n",
    "                # zero the parameter gradients\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                # forward\n",
    "                with torch.set_grad_enabled(phase == \"train\"):\n",
    "                    # Get model outputs and calculate loss\n",
    "                    outputs = model(inputs)\n",
    "                    loss = criterion(outputs, labels)\n",
    "\n",
    "                    # backward + optimize only if in training phase\n",
    "                    if phase == \"train\":\n",
    "                        loss.backward()\n",
    "                        optimizer.step()\n",
    "\n",
    "                # calculate statistics\n",
    "                running_loss += loss.item() * inputs.size(0)\n",
    "                running_corrects += evaluate(outputs, labels)\n",
    "\n",
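    "            # With more than one worker, the distributed sampler pads each shard\n",
    "            # to ceil(N / world_size) samples, so round up instead of down\n",
    "            world_size = train.get_context().get_world_size()\n",
    "            size = (len(torch_datasets[phase]) + world_size - 1) // world_size\n",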
    "            epoch_loss = running_loss / size\n",
    "            epoch_acc = running_corrects / size\n",
    "\n",
    "            if train.get_context().get_world_rank() == 0:\n",
    "                print(\n",
    "                    \"Epoch {}-{} Loss: {:.4f} Acc: {:.4f}\".format(\n",
    "                        epoch, phase, epoch_loss, epoch_acc\n",
    "                    )\n",
    "                )\n",
    "\n",
    "            # Report metrics and checkpoint every epoch\n",
    "            if phase == \"val\":\n",
    "                with TemporaryDirectory() as tmpdir:\n",
    "                    state_dict = {\n",
    "                        \"epoch\": epoch,\n",
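    "                        # Unwrap DDP if present; a single-worker model isn't wrapped\n",
    "                        \"model\": (\n",
    "                            model.module if hasattr(model, \"module\") else model\n",
    "                        ).state_dict(),\n",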
    "                        \"optimizer_state_dict\": optimizer.state_dict(),\n",
    "                    }\n",
    "                    torch.save(state_dict, os.path.join(tmpdir, \"checkpoint.pt\"))\n",
    "                    train.report(\n",
    "                        metrics={\"loss\": epoch_loss, \"acc\": epoch_acc},\n",
    "                        checkpoint=Checkpoint.from_directory(tmpdir),\n",
    "                    )\n"
   ]
  },
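  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The arithmetic behind the per-worker batch size can be checked in isolation. The following sketch is illustrative only: `shard_plan` is a hypothetical helper, the sample count of 244 is an assumed training-set size, and it assumes the distributed sampler pads every shard to the same length rather than dropping the remainder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def shard_plan(num_samples, global_batch_size, world_size):\n",
    "    # Per-worker batch size, computed the same way as in train_loop_per_worker\n",
    "    worker_batch_size = global_batch_size // world_size\n",
    "    # Assumption: shards are padded so every worker sees the same sample count\n",
    "    samples_per_worker = math.ceil(num_samples / world_size)\n",
    "    # Number of optimizer steps each worker takes per epoch\n",
    "    steps_per_epoch = math.ceil(samples_per_worker / worker_batch_size)\n",
    "    return worker_batch_size, samples_per_worker, steps_per_epoch\n",
    "\n",
    "\n",
    "# Assumed numbers: 244 training images, batch_size=32, 4 workers\n",
    "print(shard_plan(244, 32, 4))  # (8, 61, 8)\n"
   ]
  },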
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, set up the `TorchTrainer`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ray.train.torch import TorchTrainer\n",
    "from ray.train import ScalingConfig, RunConfig, CheckpointConfig\n",
    "\n",
    "# Scale out model training across 4 GPUs.\n",
    "scaling_config = ScalingConfig(\n",
    "    num_workers=4, use_gpu=True, resources_per_worker={\"CPU\": 1, \"GPU\": 1}\n",
    ")\n",
    "\n",
    "# Keep only the latest checkpoint\n",
    "checkpoint_config = CheckpointConfig(num_to_keep=1)\n",
    "\n",
    "# Set experiment name and checkpoint configs\n",
    "run_config = RunConfig(\n",
    "    name=\"finetune-resnet\",\n",
    "    storage_path=\"/tmp/ray_results\",\n",
    "    checkpoint_config=checkpoint_config,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "if SMOKE_TEST:\n",
    "    scaling_config = ScalingConfig(\n",
    "        num_workers=8, use_gpu=False, resources_per_worker={\"CPU\": 1}\n",
    "    )\n",
    "    train_loop_config[\"num_epochs\"] = 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = TorchTrainer(\n",
    "    train_loop_per_worker=train_loop_per_worker,\n",
    "    train_loop_config=train_loop_config,\n",
    "    scaling_config=scaling_config,\n",
    "    run_config=run_config,\n",
    ")\n",
    "\n",
    "result = trainer.fit()\n",
    "print(result)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the checkpoint for prediction\n",
    "\n",
    "Ray Train has already saved the metadata and checkpoints under the `storage_path` specified in `RunConfig`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now load the trained model and evaluate it on the validation split. The latest checkpoint from the fine-tuning run is available as `result.checkpoint`. Load it with the previously defined `initialize_model_from_checkpoint()` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = initialize_model_from_checkpoint(result.checkpoint)\n",
    "# Fall back to CPU when no GPU is available\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "if SMOKE_TEST:\n",
    "    device = torch.device(\"cpu\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, define a simple evaluation loop and check the performance of the checkpoint model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy:  0.934640522875817\n"
     ]
    }
   ],
   "source": [
    "model = model.to(device)\n",
    "model.eval()\n",
    "\n",
    "download_datasets()\n",
    "torch_datasets = build_datasets()\n",
    "dataloader = DataLoader(torch_datasets[\"val\"], batch_size=32, num_workers=4)\n",
    "corrects = 0\n",
    "# Gradients aren't needed for evaluation\n",
    "with torch.no_grad():\n",
    "    for inputs, labels in dataloader:\n",
    "        inputs = inputs.to(device)\n",
    "        labels = labels.to(device)\n",
    "        logits = model(inputs)\n",
    "        corrects += evaluate(logits, labels)\n",
    "\n",
    "print(\"Accuracy: \", corrects / len(dataloader.dataset))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orphan": true,
  "vscode": {
   "interpreter": {
    "hash": "a8c1140d108077f4faeb76b2438f85e4ed675f93d004359552883616a1acd54c"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
